{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 量化 QPartitionExpr\n",
    "\n",
    "下面以表达式 $f(x, y) = (x + y)(x -y)$ 为例展示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">10</span>), float32]) {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y);\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y);\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> multiply(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>);\n",
       "  exp(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "x = relay.var(\"x\", dtype=\"float32\", shape=(10,))\n",
    "y = relay.var(\"y\", dtype=\"float32\", shape=(10,))\n",
    "z1 = x + y\n",
    "z2 = x - y\n",
    "z3 = z1 * z2\n",
    "z4 = relay.exp(z3)\n",
    "mod = tvm.IRModule.from_expr(z4)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义分区"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, PartitionedFromPattern<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;add_subtract_multiply_&quot;</span>, Composite<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler.add_subtract_multiply&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] {\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    multiply(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "  } <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>fn (Tensor[(<span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>), float32]) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  exp(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tvm.relay.quantize._partition import (\n",
    "    register_partition_function,\n",
    "    QPartitionExpr,\n",
    "    partition_expr_check\n",
    ")\n",
    "from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard, is_var\n",
    "from tvm.relay import Call\n",
    "from tvm.relay.function import Function, FunctionWithFields\n",
    "@tvm.relay.transform.function_pass(opt_level=1)\n",
    "class MergeGraphTransform:\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "        \n",
    "    def reset(self):\n",
    "        self.nodes = []\n",
    "\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        obj = self\n",
    "        class Replace(tvm.relay.ExprMutator):\n",
    "            def visit_function(self, fn):\n",
    "                new_params = [self.visit(x) for x in fn.params]\n",
    "                new_body = self.visit(fn.body)\n",
    "                new_body = QPartitionExpr(new_body).realize()\n",
    "                if new_params == list(fn.params) and new_body == fn.body:\n",
    "                    new_fn =  fn\n",
    "                else:\n",
    "                    new_fn = FunctionWithFields(fn, list(new_params), new_body)\n",
    "                obj.nodes.append(new_fn)\n",
    "                return new_fn\n",
    "        return Replace().visit(func)\n",
    "\n",
    "def make_add_subtract_multiply_pattern():\n",
    "    \"\"\"查找模式\n",
    "        (x + y)(x - y)\n",
    "    \"\"\"\n",
    "    x = is_var()\n",
    "    y = is_var()\n",
    "    node1 = is_op(\"add\")(x, y)\n",
    "    node2 = is_op(\"subtract\")(x, y)\n",
    "    node = is_op(\"multiply\")(node1, node2)\n",
    "    return node\n",
    "compiler_name = \"ccompiler\"\n",
    "pattern_table = [\n",
    "    (f\"{compiler_name}.add_subtract_multiply\", make_add_subtract_multiply_pattern()),\n",
    "]\n",
    "merge_passes = tvm.transform.Sequential([\n",
    "    relay.transform.MergeComposite(pattern_table),\n",
    "    # relay.transform.AnnotateTarget([compiler_name]),\n",
    "    relay.transform.PartitionGraph(),\n",
    "    # relay.transform.ToANormalForm()\n",
    "])\n",
    "run_mod = merge_passes(mod)\n",
    "run_mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1: Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, PartitionedFromPattern<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;add_subtract_multiply_&quot;</span>, Composite<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler.add_subtract_multiply&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] {\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_0, <span style=\"color: #A2F; font-weight: bold\">%</span>FunctionVar_0_1) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> multiply(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> annotation<span style=\"color: #A2F; font-weight: bold\">.</span>cast_hint(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    annotation<span style=\"color: #A2F; font-weight: bold\">.</span>stop_fusion(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "  } <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>fn (Tensor[(<span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>), float32]) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">5</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">6</span> <span style=\"color: #A2F; font-weight: bold\">=</span> exp(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">5</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">7</span> <span style=\"color: #A2F; font-weight: bold\">=</span> annotation<span style=\"color: #A2F; font-weight: bold\">.</span>cast_hint(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">6</span>, dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  annotation<span style=\"color: #A2F; font-weight: bold\">.</span>stop_fusion(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">7</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "transform = MergeGraphTransform()\n",
    "run_mod = transform(run_mod)\n",
    "run_mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从数学角度来看，上述问题可以化简为 $f(x, y) = x^2 - y^2$："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.dataflow_pattern import DFPatternCallback\n",
    "\n",
    "class MergeGraphCallback(DFPatternCallback):\n",
    "    # A callback class to rewrite the matched pattern to a batch_norm op.\n",
    "    def __init__(self, require_type=False):\n",
    "        super().__init__(require_type)\n",
    "        self.pattern = make_add_subtract_multiply_pattern()\n",
    "\n",
    "    def callback(self, pre, post, node_map):\n",
    "        x = post.args[0].args[0] * post.args[0].args[0]\n",
    "        y = post.args[0].args[1] * post.args[0].args[1]\n",
    "        return x - y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "fn (%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {\n",
       "  %0 = multiply(%x, %x);\n",
       "  %1 = multiply(%y, %y);\n",
       "  %2 = subtract(%0, %1);\n",
       "  %3 = annotation.cast_hint(%2, dtype=\"int8\") /* ty=Tensor[(10), float32] */;\n",
       "  %4 = annotation.stop_fusion(%3) /* ty=Tensor[(10), float32] */;\n",
       "  %5 = exp(%4) /* ty=Tensor[(10), float32] */;\n",
       "  %6 = annotation.cast_hint(%5, dtype=\"int8\") /* ty=Tensor[(10), float32] */;\n",
       "  annotation.stop_fusion(%6) /* ty=Tensor[(10), float32] */\n",
       "} /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tvm.relay.dataflow_pattern import rewrite\n",
    "\n",
    "rewrite(MergeGraphCallback(), relay.transform.DefuseOps()(run_mod)[\"main\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvm-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
